"""YAML utils."""

import argparse
import copy
import functools
from importlib import resources
import json
import os
import pathlib
import re
import sys
import typing

import jsonschema
import sentry_sdk
import yaml

from . import config_tree
from . import misc
from .session import get_session

# Shared HTTP session; _resolve_includes() uses it to fetch includes given as URLs.
# NOTE(review): raise_for_status=True presumably turns HTTP error responses
# into exceptions - confirm in .session.get_session().
SESSION = get_session(__name__, raise_for_status=True)


# pylint: disable=too-many-ancestors
class BlockDumper(yaml.SafeDumper):
    """Safe dumper with indented lists and block-style multi-line strings."""

    def increase_indent(self, flow: bool = False, indentless: bool = False) -> None:
        """Increase the indentation in a way that keeps yamllint satisfied.

        The ``indentless`` argument is ignored: ``False`` is always passed on
        to the parent implementation.
        """
        never_indentless = False
        return super().increase_indent(flow=flow, indentless=never_indentless)

    def represent_str(self, data: str) -> yaml.ScalarNode:
        """Represent a string, switching to the ``|`` literal style if it has newlines."""
        if '\n' in data:
            return self.represent_scalar('tag:yaml.org,2002:str', data, style='|')
        # Single-line strings keep the default representation.
        return super().represent_str(data)


class ReferenceList(typing.List[typing.Any]):
    """Holds references like !reference [path, to, node]."""

    # Maximum number of reference hops followed while resolving one value.
    # Keeps circular references from recursing forever.
    _MAX_DEPTH = 10

    @staticmethod
    def find_reference(path: 'ReferenceList', root: typing.Any = None) -> typing.Any:
        """Find the node belonging to a !reference [path, to, node].

        Args:
            path: Keys/indices to follow one after the other, starting at root.
            root: Top-level dictionary that the path is relative to.

        Returns:
            A deep copy of the node found at the end of the path.

        Raises:
            Exception: If root is not a dictionary.
            KeyError, IndexError, TypeError: If a step of the path cannot be
                followed (raised by the underlying item access).
        """
        if not isinstance(root, dict):
            raise Exception('Root node is not a dictionary')
        current = root
        for step in path:
            current = current[step]
        return copy.deepcopy(current)

    @staticmethod
    def _resolve_value(value: typing.Any, root: typing.Any, depth: int) -> typing.Any:
        """Replace value by its referenced node and resolve everything inside."""
        # A reference may point at another reference, so follow the chain.
        while isinstance(value, ReferenceList):
            if depth >= ReferenceList._MAX_DEPTH:
                raise ValueError(
                    f'References nested deeper than {ReferenceList._MAX_DEPTH} levels '
                    f'(circular reference?): {list(value)}')
            value = ReferenceList.find_reference(value, root)
            depth += 1
        # The copied node can contain references of its own that were not
        # resolved yet in root; resolve them here so the result does not
        # depend on the order in which nodes are visited.
        return ReferenceList._resolve_container(value, root, depth)

    @staticmethod
    def _resolve_container(data: typing.Any, root: typing.Any, depth: int) -> typing.Any:
        """Resolve all references below a dict or list in place; pass others through."""
        if isinstance(data, dict):
            items: typing.Iterable[typing.Any] = data.items()
        elif isinstance(data, list):
            items = enumerate(data)
        else:
            return data
        for key, value in items:
            # Only existing keys/indices are reassigned, which is safe while iterating.
            data[key] = ReferenceList._resolve_value(value, root, depth)
        return data

    @staticmethod
    def resolve_references(data: typing.Any, root: typing.Any = None) -> typing.Any:
        """Resolve references like !reference [path, to, node].

        Dictionaries and lists are modified in place; references are replaced
        by deep copies of the nodes they point to. References contained in a
        referenced node, and references pointing to other references, are
        resolved as well.

        Args:
            data: Node to process; anything but dicts and lists is returned as is.
            root: Dictionary that reference paths are relative to, defaults to data.

        Returns:
            The data that was passed in, with references resolved.

        Raises:
            Exception: If a reference is found but root is not a dictionary.
            ValueError: If references are nested too deeply, e.g. when they
                are circular.
        """
        if root is None:
            root = data
        return ReferenceList._resolve_container(data, root, 0)


class ReferenceLoader(yaml.SafeLoader):
    """Safe loader providing a constructor for !reference [path, to, node]."""

    def construct_reference(self, node: yaml.nodes.Node) -> ReferenceList:
        """Turn the sequence node of a reference into a ReferenceList."""
        path_steps = self.construct_sequence(node)  # type: ignore
        return ReferenceList(path_steps)


# Hook the custom handlers into PyYAML: strings dumped with BlockDumper go
# through BlockDumper.represent_str, and !reference tags loaded with
# ReferenceLoader are built by ReferenceLoader.construct_reference.
yaml.add_representer(str, BlockDumper.represent_str, Dumper=BlockDumper)
yaml.add_constructor('!reference', ReferenceLoader.construct_reference, Loader=ReferenceLoader)


def _resolve_includes(
        data: typing.Dict[str, typing.Any],
        relative_to: typing.Union[pathlib.Path, str, None],
) -> typing.Dict[str, typing.Any]:
    """Merge the YAML documents listed under the '.include' key with data.

    The '.include' key is removed from data. Entries starting with http:// or
    https:// are downloaded via SESSION, all others are read as files relative
    to the directory of relative_to.

    Args:
        data: Loaded document, possibly containing an '.include' key.
        relative_to: Path of the file that data was loaded from; without it
            (or if it has no directory part) the current working directory is
            used for relative includes.

    Returns:
        data itself if there is no '.include' key; otherwise a new dictionary
        with the includes merged in order, followed by data.
    """
    if (includes := data.pop('.include', None)) is None:
        return data
    # os.path.dirname() returns '' for a bare file name such as 'config.yml'
    # (or the 'stdin' placeholder of the CLI). An empty prefix would turn
    # '<dir>/<include>' into the absolute path '/<include>', so fall back to
    # the current directory in that case as well.
    root_directory = (os.path.dirname(relative_to) if relative_to else '') or os.getcwd()
    result: typing.Dict[typing.Any, typing.Any] = {}
    # NOTE(review): misc.flattened() presumably accepts a single entry as well
    # as (nested) lists of entries - confirm in misc.
    for include in misc.flattened(includes):
        if include.startswith(('https://', 'http://')):
            config_tree.merge_dicts(result, load(contents=SESSION.get(include).text))
        else:
            config_tree.merge_dicts(result, load(file_path=f'{root_directory}/{include}'))
    # The including document is merged last.
    # NOTE(review): presumably this lets its values win over included ones -
    # confirm against config_tree.merge_dicts().
    config_tree.merge_dicts(result, data)
    return result


def load(*,
         file_path: pathlib.Path | str | None = None,
         contents: str | None = None,
         resolve_references: bool | None = None,
         resolve_includes: bool | None = None,
         loader: type[yaml.SafeLoader] | None = None,
         schema_path: pathlib.Path | str | None = None,
         format_checker: jsonschema.FormatChecker | None = None,
         process_config_tree: bool | None = None,
         load_all: bool = False,
         ) -> typing.Any:
    # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-branches
    """Load a YAML file and optionally validate it with a schema.

    Args:
        file_path: File to read if contents is not given; also the location
            that relative '.include' entries are resolved against.
        contents: YAML text; if given, file_path is not read.
        resolve_references: support !reference tag as described at
            <https://docs.gitlab.com/ee/ci/yaml/#reference-tags>
        resolve_includes: support one level of .include: <url|file_name>
        loader: YAML loader class; ReferenceLoader if not given.
        schema_path: Schema file to validate against. If not given, a
            top-level '.schema' key of the form '<package>/<resource>' in a
            document selects a schema shipped as a package resource. The
            '.schema' key is removed from the document in any case.
        format_checker: Passed on to validate().
        process_config_tree: Run config_tree.process_config_tree() on the data.
        load_all: Return a list of all documents instead of the first one.

    Options among resolve_references, resolve_includes and process_config_tree
    that are left at None are enabled if the schema sets the corresponding
    $ckiResolveReferences, $ckiResolveIncludes or $ckiProcessConfigTree key.

    Returns:
        The first document, or None if there are no contents or documents.
        With load_all, a list of all documents, or [None] if there are none.

    Raises:
        ValidationError: If a document does not conform to the schema.
    """
    if contents is None and file_path:
        path = pathlib.Path(file_path) if isinstance(file_path, str) else file_path
        contents = path.read_text(encoding='utf8')
    if contents is None:
        return [None] if load_all else None
    results = []
    schema: typing.Any = None
    schema_loaded = False
    for result in yaml.load_all(contents, Loader=loader or ReferenceLoader):
        # Always pop '.schema' so it does not end up in the returned data; it
        # only selects the schema if none was chosen before.
        if (
            isinstance(result, dict)
            and (schema_ref := result.pop('.schema', None))
            and not schema_path
        ):
            anchor, resource = schema_ref.split('/', 1)
            schema_path = resources.files(anchor) / resource
        # Once set, schema_path never changes, so the schema file is read and
        # parsed only once, even when several documents are loaded.
        if schema_path and not schema_loaded:
            schema_path = pathlib.Path(schema_path) if isinstance(schema_path, str) else schema_path
            schema = yaml.safe_load(schema_path.read_text(encoding='utf8'))
            schema_loaded = True
            # The schema can switch on processing steps the caller left undecided.
            if schema.get('$ckiResolveIncludes') and resolve_includes is None:
                resolve_includes = True
            if schema.get('$ckiResolveReferences') and resolve_references is None:
                resolve_references = True
            if schema.get('$ckiProcessConfigTree') and process_config_tree is None:
                process_config_tree = True
        if resolve_includes and isinstance(result, dict):
            result = _resolve_includes(result, relative_to=file_path)
        if resolve_references:
            result = ReferenceList.resolve_references(result)
        if process_config_tree:
            result = config_tree.process_config_tree(result)
        # Validation happens last, on the fully processed data.
        if schema_path:
            validate(result, schema, format_checker=format_checker)
        if not load_all:
            return result
        results.append(result)
    if not results:
        return [None] if load_all else None
    return results


def dump(data: typing.Any, **kwargs: typing.Any) -> typing.Any:
    """Serialize data as YAML with BlockDumper and an explicit document start.

    Additional keyword arguments are handed to yaml.dump() unchanged.
    """
    fixed_options = {'Dumper': BlockDumper, 'explicit_start': True}
    return yaml.dump(data, **fixed_options, **kwargs)


class ValidationError(jsonschema.exceptions.ValidationError):
    """Raised when data does not conform to the given schema."""

    # This class abstracts the error raised by the third-party JSON schema
    # validation, while still benefiting from the detailed output of its
    # __str__() method.

    def __init__(
        self,
        message: str,
        validation_error: jsonschema.exceptions.ValidationError = None,
    ):
        """Initialize error instance.

        Args:
            message: Short explanation of what went wrong.
            validation_error: Optional error raised by the 3rd-party JSON
                schema validation; its message is appended to message and
                its details are taken over.
        """
        details: typing.Dict[str, typing.Any] = {}
        if validation_error:
            message = ' '.join((message, validation_error.message))
            # Take over the public details of the wrapped error ...
            details = {
                name: getattr(validation_error, name)
                for name in (
                    'validator',
                    'path',
                    'cause',
                    'context',
                    'validator_value',
                    'instance',
                    'schema',
                    'schema_path',
                    'parent',
                )
            }
            # ... and the type checker, which is only stored privately.
            details['type_checker'] = validation_error._type_checker

        super().__init__(message, **details)


# Type alias used for the instance and schema arguments of validate().
YamlData = typing.Union[typing.Dict, typing.List, str]


def validate(
    instance: YamlData,
    schema: YamlData,
    format_checker: typing.Optional[jsonschema.FormatChecker] = None,
    error_str: str = 'Yaml is not valid.'
) -> YamlData:
    """Use jsonschema.validate to validate the given yaml data instance.

    Args:
        instance: Data to validate.
        schema: JSON schema to validate against.
        format_checker: Optional checker for "format" keywords, passed on to
            jsonschema.validate().
        error_str: Text that the message of a raised ValidationError starts with.

    Returns:
        The unmodified instance if it is valid.

    Raises:
        ValidationError: If the instance does not conform to the schema.
    """
    try:
        jsonschema.validate(instance=instance, schema=schema, format_checker=format_checker)
    except jsonschema.exceptions.ValidationError as validation_err:
        # The details of the third-party error are copied into the wrapper,
        # so the original exception is dropped from the traceback chain.
        raise ValidationError(error_str, validation_err) from None
    return instance


def _validate_cli_one(
    **kwargs: typing.Any,
) -> int:
    """Load one file with the given load() arguments and return its exit code.

    Returns 0 if loading succeeds; a ValidationError is printed together with
    the file path and results in 1.
    """
    exit_code = 0
    try:
        load(**kwargs)
    except ValidationError as error:
        print(f'{kwargs["file_path"]}: {error}')
        exit_code = 1
    return exit_code


def _validate_cli(parsed_args: typing.Any) -> None:
    """Validate the given files (or stdin) and exit with the highest exit code."""
    if parsed_args.files:
        load_args = [{'file_path': name} for name in parsed_args.files]
    else:
        # Without files, validate what arrives on standard input.
        load_args = [{'file_path': 'stdin', 'contents': sys.stdin.read()}]
    exit_codes = [
        _validate_cli_one(
            process_config_tree=parsed_args.process_config_tree,
            resolve_includes=parsed_args.resolve_includes,
            resolve_references=parsed_args.resolve_references,
            schema_path=parsed_args.schema,
            **per_file,
        )
        for per_file in load_args
    ]
    sys.exit(max(exit_codes))


def _split_dot(data: str) -> list[str]:
    """Split string on dots, but ignore backslash+dot."""
    return list(t.replace(r'\.', '.') for t in re.split(r'(?<!\\)\.', data))


def _key(target: typing.Any, key: str) -> str | int:
    """Prepare a key for list access if necessary."""
    return int(key) if isinstance(target, list) else key


def _dump(data: typing.Any, output_format: str) -> None:
    if output_format == 'yaml':
        dump(data, stream=sys.stdout)
    else:
        json.dump(data, sys.stdout)


def _set_value_cli(parsed_args: typing.Any) -> None:
    """Set a dot-delimited key in the YAML from stdin and print the result."""
    document = load(contents=sys.stdin.read())
    *parents, leaf = _split_dot(parsed_args.key)
    # Walk down to the container holding the last key.
    container = document
    for name in parents:
        container = container[_key(container, name)]
    container[_key(container, leaf)] = load(contents=parsed_args.value)
    _dump(document, parsed_args.format)


def _del_cli(parsed_args: typing.Any) -> None:
    """Delete a dot-delimited key from the YAML from stdin and print the result."""
    document = load(contents=sys.stdin.read())
    *parents, leaf = _split_dot(parsed_args.key)
    # Walk down to the container holding the last key.
    container = document
    for name in parents:
        container = container[_key(container, name)]
    del container[_key(container, leaf)]
    _dump(document, parsed_args.format)


def _dump_cli(parsed_args: typing.Any) -> None:
    """Read YAML from stdin and print it in the requested output format."""
    stdin_text = sys.stdin.read()
    _dump(load(contents=stdin_text), parsed_args.format)


def main(args: list[str] | None = None) -> None:
    """Parse the command line and run the selected sub-command.

    Available sub-commands are validate, set-value, del and dump.

    Args:
        args: Command line arguments; passed to argparse, which falls back to
            sys.argv if None.
    """
    parser = argparse.ArgumentParser(description='Work with YAML files')
    subparsers = parser.add_subparsers(dest='action', required=True)

    # validate: check files against a JSON schema
    validate_parser = subparsers.add_parser(
        'validate', help='Validate YAML files', description='Validate YAML files')
    validate_parser.set_defaults(func=_validate_cli)
    validate_parser.add_argument('files', nargs='*', help='paths of YAML files to validate')
    validate_parser.add_argument('--schema', help='path of JSON schema file')
    # default=None instead of False leaves the decision to load() if a flag is missing
    for flag, flag_help in (
        ('--process-config-tree', 'support .default and .extends'),
        ('--resolve-includes', 'support .include'),
        ('--resolve-references', 'support !reference tags'),
    ):
        validate_parser.add_argument(flag, action='store_true', default=None, help=flag_help)

    # set-value: assign a value below a dot-delimited key
    set_value_parser = subparsers.add_parser(
        'set-value', help='Set a dot-delimited key', description='Set a dot-delimited key')
    set_value_parser.set_defaults(func=_set_value_cli)
    set_value_parser.add_argument('key', help='dot-delimited key')
    set_value_parser.add_argument('value', help='YAML-formatted value')

    # del: remove a dot-delimited key
    del_parser = subparsers.add_parser(
        'del', help='Delete a dot-delimited key', description='Delete a dot-delimited key')
    del_parser.set_defaults(func=_del_cli)
    del_parser.add_argument('key', help='dot-delimited key')

    # dump: reformat the input
    dump_parser = subparsers.add_parser(
        'dump', help='Dump a YAML data', description='Dump YAML data')
    dump_parser.set_defaults(func=_dump_cli)

    # All sub-commands that print data share the same output format option.
    for printing_parser in (set_value_parser, del_parser, dump_parser):
        printing_parser.add_argument('--format', choices=('yaml', 'json'), default='yaml',
                                     help='output format (default: yaml)')

    parsed_args = parser.parse_args(args)
    parsed_args.func(parsed_args)


if __name__ == '__main__':
    # Script entry point: initialize Sentry through the project helper, then run the CLI.
    # NOTE(review): assumes misc.sentry_init() sets up error reporting - confirm in misc.
    misc.sentry_init(sentry_sdk)
    main()
